import copy
from typing import Callable, Dict, Tuple

import torch
import torch.nn as nn
from torch import Tensor
import numpy as np

from torch_geometric.nn import TransformerConv
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import scatter

# Per-node message store used by TGNMemory: node id -> tuple of tensors
# (src, dst, t, raw_msg) holding the messages currently stored for that node.
TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]


class TGN(nn.Module):
    """Temporal Graph Network scoring interactions for binary classification.

    Combines a :class:`TGNMemory`, a stack of :class:`GraphAttentionEmbedding`
    layers over each node's most recent neighbors, and a
    :class:`ModifiedReadout` that produces one logit per (src, dst) event.

    Args:
        num_nodes: Number of nodes to keep memory / neighbor state for.
        msg_dim: Dimensionality of the raw event messages.
        model_dim: Shared memory, time-encoding and embedding dimensionality.
        num_neighbors: Number of most recent neighbors kept per node.
        num_gnn_layers: Number of stacked graph-attention layers.
        time_encoding_version: ``"fix"`` for frozen frequencies, anything
            else for a learned time encoder.
        aggregator_version: ``"mean"``, ``"exp"``, or anything else for
            "last message" aggregation.
        readout_version: Readout variant, ``0`` to ``5`` (see ModifiedReadout).
        device: Device holding the model, its state and the storage.
        storage: Event store indexed by event id, exposing ``t`` and ``msg``
            tensors and a ``.to(device)`` method. May be ``None`` at
            construction, but must be set before :meth:`forward` is called.
        **kwargs: Ignored.
    """

    def __init__(
            self,
            num_nodes: int,
            msg_dim: int,
            model_dim: int = 100,
            num_neighbors: int = 10,
            num_gnn_layers: int = 2,
            time_encoding_version: str = "learn",
            aggregator_version: str = "last",
            readout_version: int = 0,
            device=None,
            storage=None,
            **kwargs
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.msg_dim = msg_dim
        self.memory_dim, self.time_dim, self.embedding_dim = model_dim, model_dim, model_dim
        self.num_neighbors = num_neighbors
        self.time_encoding_version = time_encoding_version
        self.aggregator_version = aggregator_version
        self.readout_version = readout_version
        self.device = device
        # `storage` defaults to None; only move it when one was actually given.
        self.storage = storage.to(device) if storage is not None else None

        self.criterion = nn.BCEWithLogitsLoss()

        self.memory = TGNMemory(
            self.num_nodes,
            self.msg_dim,
            self.memory_dim,
            self.time_dim,
            message_module=IdentityMessage(self.msg_dim, self.memory_dim, self.time_dim),
            aggregator_module=ModifiedAggregator(version=self.aggregator_version),
            time_encoder=ModifiedTimeEncoder(self.time_dim, version=self.time_encoding_version)
        ).to(device)

        # All GNN layers (and the readout) share the memory's time encoder.
        self.gnn = nn.ModuleList()
        for _ in range(num_gnn_layers):
            self.gnn.append(
                GraphAttentionEmbedding(
                    in_channels=self.memory_dim,
                    out_channels=self.embedding_dim,
                    msg_dim=self.msg_dim,
                    time_enc=self.memory.time_enc
                ).to(device)
            )

        self.readout = ModifiedReadout(
            embedding_dim=self.embedding_dim,
            msg_dim=self.msg_dim,
            time_dim=self.time_dim,
            time_enc=self.memory.time_enc,
            version=self.readout_version
        ).to(device)

        self.neighbor_loader = LastNeighborLoader(
            self.num_nodes,
            size=num_neighbors,
            device=device
        )

        # Scratch map from global node id to row of the sampled subgraph.
        self.assoc = torch.empty(self.num_nodes, dtype=torch.long, device=device)

    def parameters(self, **kwargs):
        """Return the parameters of memory, GNN layers and readout.

        The time encoder is shared by several submodules, so its parameters
        are reachable more than once; each parameter is returned only once.
        The result is a list in a deterministic order (memory, then GNN
        layers, then readout). Optimizers need that fixed order to save and
        reload their state correctly. Keyword arguments are accepted and
        ignored.
        """
        seen = set()
        unique_params = []
        for module in (self.memory, self.gnn, self.readout):
            for param in module.parameters():
                if id(param) not in seen:
                    seen.add(id(param))
                    unique_params.append(param)
        return unique_params

    def reset_parameters(self):
        """No-op: parameters are not re-initialised here."""
        pass

    def reset_state(self):
        """Reset the node memory and the neighbor loader."""
        self.memory.reset_state()
        self.neighbor_loader.reset_state()

    def detach(self):
        """Detach the node memory from the autograd graph."""
        self.memory.detach()

    def forward(self, batch):
        """Score a batch of events, then fold it into the temporal state.

        Args:
            batch: Event batch exposing ``n_id``, ``src``, ``dst``, ``t``,
                ``msg``, ``y`` and ``.to(device)``. Presumably ``n_id``
                contains every node of ``src`` and ``dst`` (their rows are
                looked up through ``self.assoc``) — verify against callers.

        Returns:
            ``(loss, y_pred, y_true)``: the BCE-with-logits loss, sigmoid
            probabilities on CPU (detached), and the float labels on CPU with
            a trailing singleton dimension.

        Raises:
            RuntimeError: if no event storage has been attached.
        """
        if self.storage is None:
            raise RuntimeError("TGN.forward requires an event storage; pass `storage=` to TGN")
        batch = batch.to(self.device)
        n_id, edge_index, e_id = self.neighbor_loader(batch.n_id)
        self.assoc[n_id] = torch.arange(n_id.size(0), device=self.device)

        # Edge features of the sampled subgraph are read from the storage.
        t = self.storage.t[e_id]
        msg = self.storage.msg[e_id]

        z, last_update = self.memory(n_id)
        for layer in self.gnn:
            z = layer(z, last_update, edge_index, t, msg)
        logits = self.readout(z[self.assoc[batch.src]], z[self.assoc[batch.dst]], batch.msg, batch.t)
        y = batch.y.float().unsqueeze(-1)
        loss = self.criterion(logits, y)

        # The state is updated only after scoring, so predictions use the
        # memory/neighbors as they were before this batch.
        self.memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        self.neighbor_loader.insert(batch.src, batch.dst)

        y_pred = logits.detach().cpu().sigmoid()
        y_true = y.detach().cpu()
        return loss, y_pred, y_true


class ModifiedReadout(nn.Module):
    """Link-scoring head combining node embeddings, message and time.

    Each input gets its own bias-free projection to ``embedding_dim``. The
    projections are summed, passed through ReLU and mapped to a single logit
    by ``lin_final``. ``version`` selects which inputs participate. A
    disabled ("cancelled") projection keeps its module, but its weight is
    frozen at zero, so it contributes nothing and receives no gradient.

    ======= =====================================
    version inputs that contribute
    ======= =====================================
    0       z_src, z_dst
    1       z_src, z_dst, msg
    2       z_src, z_dst, time_enc(t)
    3       z_dst, msg, time_enc(t)
    4       z_src, msg, time_enc(t)
    5       z_src, z_dst, msg, time_enc(t)
    ======= =====================================

    For versions 0 and 1 the time encoder is dropped (``self.time_enc`` is
    ``None``). Any other version raises :class:`NotImplementedError`.
    """

    def __init__(self, embedding_dim, msg_dim, time_dim, time_enc, version=0):
        super().__init__()
        self.version = version
        # Create the four input projections (registration order matters for
        # the parameter / state_dict order).
        for attr, in_features in (("lin_src", embedding_dim),
                                  ("lin_dst", embedding_dim),
                                  ("lin_msg", msg_dim),
                                  ("lin_t", time_dim)):
            setattr(self, attr, nn.Linear(in_features, embedding_dim, bias=False))
        self.time_enc = time_enc

        # version -> (projections to cancel, whether to drop the time encoder)
        variants = {
            0: (("lin_msg", "lin_t"), True),
            1: (("lin_t",), True),
            2: (("lin_msg",), False),
            3: (("lin_src",), False),
            4: (("lin_dst",), False),
            5: ((), False),
        }
        try:
            to_cancel, drop_time = variants[version]
        except (KeyError, TypeError):
            raise NotImplementedError from None

        for attr in to_cancel:
            setattr(self, attr, self.cancel_layer(getattr(self, attr)))
        if drop_time:
            self.time_enc = None

        self.lin_final = nn.Linear(embedding_dim, 1)

    def forward(self, z_src, z_dst, msg, t):
        """Return one logit per event, with shape ``(..., 1)``."""
        hidden = self.lin_src(z_src) + self.lin_dst(z_dst) + self.lin_msg(msg)
        if self.time_enc is not None:
            hidden = hidden + self.lin_t(self.time_enc(t.to(msg.dtype)))
        return self.lin_final(hidden.relu())

    @staticmethod
    def cancel_layer(layer):
        """Replace ``layer.weight`` with a frozen all-zero parameter; return ``layer``."""
        layer.weight = nn.Parameter(torch.zeros_like(layer.weight), requires_grad=False)
        return layer


class TGNMemory(nn.Module):
    r"""The Temporal Graph Network (TGN) memory model from the
    `"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
    <https://arxiv.org/abs/2006.10637>`_ paper.

    .. note::

        For an example of using TGN, see `examples/tgn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        tgn.py>`_.

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (nn.Module): The message function which
            combines source and destination node memory embeddings, the raw
            message and the time encoding. Must expose ``out_channels`` (the
            GRU input size).
        aggregator_module (nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
        time_encoder (nn.Module): Encoder applied to relative time stamps.
            Must provide ``reset_parameters()``.
    """
    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable, time_encoder: Callable):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        # Separate (deep-copied) message modules for src->dst and dst->src.
        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = time_encoder
        self.gru = nn.GRUCell(message_module.out_channels, memory_dim)

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        # Scratch map from global node id to local row of the current batch.
        self.register_buffer('_assoc', torch.empty(num_nodes,
                                                   dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.gru.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial state."""
        zeros(self.memory)
        zeros(self.last_update)
        self._reset_message_store()

    def detach(self):
        """Detaches the memory from gradient computation."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp.

        In training mode the stored messages are applied on the fly (without
        writing back). In evaluation mode the stored memory is returned
        directly."""
        if self.training:
            memory, last_update = self._get_updated_memory(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]

        return memory, last_update

    def update_state(self, src: Tensor, dst: Tensor, t: Tensor,
                     raw_msg: Tensor):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`.

        In training mode, the memory is first updated from the previously
        stored messages, and the new messages are stored for the next call.
        In evaluation mode, the new messages are stored first and applied
        immediately."""
        n_id = torch.cat([src, dst]).unique()

        if self.training:
            self._update_memory(n_id)
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
            self._update_memory(n_id)

    def _reset_message_store(self):
        i = self.memory.new_empty((0, ), dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim))
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    def _update_memory(self, n_id: Tensor):
        memory, last_update = self._get_updated_memory(n_id)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

    def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                                                     self.msg_s_module)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store,
                                                     self.msg_d_module)

        # Aggregate messages.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.gru(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        # NOTE(review): nodes of `n_id` without stored messages presumably get
        # the scatter fill value (0) here rather than keeping their previous
        # `last_update` — confirm this is intended.
        dim_size = self.last_update.size(0)
        last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id]

        return memory, last_update.long()

    def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,
                          raw_msg: Tensor, msg_store: TGNMessageStoreType):
        # Group events by `src` and overwrite each node's stored messages.
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType,
                     msg_module: Callable):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        raw_msg = torch.cat(raw_msg, dim=0)
        # Time elapsed since the source node's last memory update.
        t_rel = t - self.last_update[src]
        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))

        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)

        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training (``mode=True``) or evaluation mode.

        When switching from training to evaluation, all stored messages are
        flushed into the memory and the message stores are cleared.

        Returns:
            TGNMemory: ``self``, as required by :meth:`torch.nn.Module.train`
            (so ``eval()`` / ``train()`` can be chained).
        """
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self._update_memory(
                torch.arange(self.num_nodes, device=self.memory.device))
            self._reset_message_store()
        return super().train(mode)


class IdentityMessage(nn.Module):
    """Message function that concatenates its inputs without transformation.

    The message is ``[z_src | z_dst | raw_msg | t_enc]`` along the last
    dimension, so ``out_channels`` is the sum of the four input widths.
    """

    def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
        super().__init__()
        # Two memory vectors (source and destination), plus raw message and time encoding.
        self.out_channels = memory_dim * 2 + raw_msg_dim + time_dim

    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor,
                t_enc: Tensor):
        pieces = (z_src, z_dst, raw_msg, t_enc)
        return torch.cat(pieces, dim=-1)


class ModifiedAggregator(nn.Module):
    """Aggregates the messages that target the same node into one vector.

    Args:
        version: ``"mean"`` averages messages per node; ``"exp"`` averages
            them after scaling by recency weights (see :meth:`_exp_weights`);
            any other value keeps only the message with the largest ``t``
            per node ("last"). Nodes without messages get a zero vector
            in the "last" variant.
    """

    def __init__(self, version="last"):
        super().__init__()
        self.version = version

    @staticmethod
    def _exp_weights(t: Tensor) -> Tensor:
        """Return recency weights in ``[0, 1]`` for the non-empty timestamps ``t``.

        Timestamps are min-max normalised over the whole tensor (not per
        destination node). They are then exponentiated and min-max normalised
        again, so the oldest message gets weight 0 and the newest gets 1.
        If all timestamps are equal, every weight is 1. This avoids the
        0 / 0 that would otherwise put NaN into the memory update. The
        guard uses ``torch.where``, so it does not trigger a host sync.
        """
        t_min = t.min()
        t_range = t.max() - t_min
        has_range = t_range > 0
        # Subtract in t's own dtype first (exact for integer timestamps).
        rel = (t - t_min) / torch.where(has_range, t_range, torch.ones_like(t_range))
        w = rel.exp()
        w_min = w.min()
        w_range = w.max() - w_min
        w = (w - w_min) / torch.where(has_range, w_range, torch.ones_like(w_range))
        return torch.where(has_range, w, torch.ones_like(w))

    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        """Aggregate ``msg`` rows into ``dim_size`` rows grouped by ``index``."""
        if self.version == "mean":
            return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')

        if self.version == "exp":
            if t.size(0) > 0:
                msg = self._exp_weights(t).unsqueeze(-1) * msg
            return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')

        # "last": keep the most recent message per node.
        from torch_scatter import scatter_max
        _, argmax = scatter_max(t, index, dim=0, dim_size=dim_size)
        out = msg.new_zeros((dim_size, msg.size(-1)))
        mask = argmax < msg.size(0)  # Filter items with at least one entry.
        out[mask] = msg[argmax[mask]]
        return out


class ModifiedTimeEncoder(nn.Module):
    """Cosine time encoding ``cos(t * w + b)`` with ``out_channels`` features.

    Args:
        out_channels: Encoding dimensionality.
        version: ``"fix"`` uses frozen frequencies
            ``w = 1 / 10 ** linspace(0, 9, out_channels)`` and zero bias.
            These values are installed by :meth:`reset_parameters`, which
            the constructor does not call. Any other value uses a learned
            ``nn.Linear(1, out_channels)``.
    """

    def __init__(self, out_channels: int, version="learn"):
        super().__init__()
        self.out_channels = out_channels
        self.lin = nn.Linear(1, out_channels)
        self.version = version

    def reset_parameters(self):
        """(Re)initialise the projection according to ``version``.

        For ``"fix"`` the values are written in place. This keeps the
        existing Parameter objects, with their device and dtype and any
        optimizer references to them, and freezes them.
        """
        if self.version == "fix":
            freqs = torch.from_numpy(
                1 / 10 ** np.linspace(0, 9, self.out_channels, dtype=np.float32))
            with torch.no_grad():
                self.lin.weight.copy_(freqs.reshape(self.out_channels, -1))
                self.lin.bias.zero_()
            self.lin.weight.requires_grad_(False)
            self.lin.bias.requires_grad_(False)
        else:
            self.lin.reset_parameters()

    def forward(self, t: Tensor) -> Tensor:
        """Encode timestamps ``t`` (flattened) into shape ``(numel, out_channels)``."""
        # Bug fix: this used to test `self.type`, which is the inherited
        # `nn.Module.type` method, so the frozen branch was never taken.
        if self.version == "fix":
            with torch.no_grad():
                return self.lin(t.view(-1, 1)).cos()
        return self.lin(t.view(-1, 1)).cos()


class LastNeighborLoader(object):
    """Keeps, for every node, its ``size`` most recent interactions.

    Interactions are stored undirected: event ``(src, dst)`` is recorded both
    as neighbor ``dst`` of ``src`` and as neighbor ``src`` of ``dst``. Event
    ids are assigned sequentially by :meth:`insert`, starting from 0 after
    :meth:`reset_state`. A slot with event id ``-1`` is empty.

    Args:
        num_nodes: Number of nodes to track.
        size: Maximum number of neighbors remembered per node.
        device: Device of the internal buffers.
    """

    def __init__(self, num_nodes: int, size: int, device=None):
        self.size = size
        table_shape = (num_nodes, size)
        self.neighbors = torch.empty(table_shape, dtype=torch.long, device=device)
        self.e_id = torch.empty(table_shape, dtype=torch.long, device=device)
        # Scratch map from global node id to a local position.
        self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device)
        self.reset_state()

    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return the subgraph of ``n_id`` and their stored neighbors.

        Returns:
            ``(n_id, edge_index, e_id)``: the sorted unique ids of the queried
            nodes and their neighbors, the local ``[2, E]`` edge index with
            rows (neighbor, queried node), and the event id of each edge.
        """
        slot_e_id = self.e_id[n_id]
        occupied = slot_e_id >= 0  # -1 marks an empty slot
        centers = n_id.view(-1, 1).expand(-1, self.size)[occupied]
        nbrs = self.neighbors[n_id][occupied]
        edge_e_id = slot_e_id[occupied]

        subset = torch.cat([n_id, nbrs]).unique()
        self._assoc[subset] = torch.arange(subset.size(0), device=subset.device)
        edge_index = torch.stack([self._assoc[nbrs], self._assoc[centers]])
        return subset, edge_index, edge_e_id

    def insert(self, src: Tensor, dst: Tensor):
        """Record events ``(src, dst)``, keeping only the newest ``size`` per node."""
        n_events = src.size(0)
        new_ids = torch.arange(self.cur_e_id, self.cur_e_id + n_events,
                               device=src.device)
        self.cur_e_id += src.numel()

        # Each event appears twice: once for each endpoint as the center node.
        other = torch.cat([src, dst], dim=0)
        center = torch.cat([dst, src], dim=0)
        event = new_ids.repeat(2)

        # Group entries by center node.
        center, order = center.sort()
        other = other[order]
        event = event[order]

        touched = center.unique()
        n_touched = touched.numel()
        self._assoc[touched] = torch.arange(n_touched, device=touched.device)

        # Flat position of every new entry in a [n_touched, size] scratch table.
        within_row = torch.arange(center.size(0), device=center.device) % self.size
        slot = within_row + self._assoc[center] * self.size

        fresh_e_id = event.new_full((n_touched * self.size, ), -1)
        fresh_e_id[slot] = event
        fresh_e_id = fresh_e_id.view(-1, self.size)

        fresh_nbrs = event.new_empty(n_touched * self.size)
        fresh_nbrs[slot] = other
        fresh_nbrs = fresh_nbrs.view(-1, self.size)

        # Merge old and new slots; keep the `size` largest (newest) event ids.
        merged_e_id = torch.cat([self.e_id[touched, :self.size], fresh_e_id], dim=-1)
        merged_nbrs = torch.cat([self.neighbors[touched, :self.size], fresh_nbrs], dim=-1)
        kept_e_id, pick = merged_e_id.topk(self.size, dim=-1)
        self.e_id[touched] = kept_e_id
        self.neighbors[touched] = torch.gather(merged_nbrs, 1, pick)

    def reset_state(self):
        """Forget all interactions and restart event ids at 0."""
        self.e_id.fill_(-1)
        self.cur_e_id = 0


class GraphAttentionEmbedding(nn.Module):
    """One temporal graph-attention layer (``TransformerConv``, 2 heads).

    Edge features are ``[time_enc(last_update[source] - t) | msg]``. Here
    ``source`` is ``edge_index[0]`` and ``t`` / ``msg`` are per-edge event
    times and messages.

    NOTE(review): with 2 heads of ``out_channels // 2`` features each, the
    output width is presumably ``2 * (out_channels // 2)``. An odd
    ``out_channels`` would lose one feature — confirm callers use even sizes.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=time_enc.out_channels + msg_dim,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        source = edge_index[0]
        delta_t = last_update[source] - t
        edge_features = torch.cat([self.time_enc(delta_t.to(x.dtype)), msg], dim=-1)
        return self.conv(x, edge_index, edge_features)



